{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iJabaHf5n8Jp"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import requests\n",
        "import torch\n",
        "from PIL import Image\n",
        "from transformers import DetrImageProcessor, DetrForObjectDetection\n",
        "\n",
        "# Load the image processor and the detector from a single checkpoint id so the\n",
        "# two can never be pointed at different models by accident.\n",
        "CHECKPOINT = \"facebook/detr-resnet-50\"\n",
        "\n",
        "processor = DetrImageProcessor.from_pretrained(CHECKPOINT)\n",
        "model = DetrForObjectDetection.from_pretrained(CHECKPOINT)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "url = \"http://images.cocodataset.org/val2017/000000439715.jpg\"\n",
        "\n",
        "# Bound the wait and fail on HTTP errors right here, instead of handing an error\n",
        "# page to PIL and getting an unrelated decoding failure.\n",
        "response = requests.get(url, stream=True, timeout=30)\n",
        "response.raise_for_status()\n",
        "\n",
        "# convert() decodes the streamed image immediately and yields a three-channel\n",
        "# RGB image regardless of the mode the file was stored in.\n",
        "image = Image.open(response.raw).convert(\"RGB\")\n"
      ],
      "metadata": {
        "id": "m8E_jNAKoEJo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Turn the PIL image into the PyTorch tensors the model takes as keyword arguments.\n",
        "inputs = processor(images=image, return_tensors=\"pt\")\n",
        "\n",
        "# Inference only: skip autograd bookkeeping so no graph is built or kept alive.\n",
        "with torch.no_grad():\n",
        "    outputs = model(**inputs)"
      ],
      "metadata": {
        "id": "GOXvc2YToJ9x"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# PIL reports size as (width, height); the sizes handed to the post-processor\n",
        "# are given in the reverse order, one row per image in the batch.\n",
        "image_width, image_height = image.size\n",
        "target_sizes = torch.tensor([[image_height, image_width]])\n",
        "\n",
        "# Keep detections scoring above 0.9; the batch holds one image, so take entry 0.\n",
        "detections_per_image = processor.post_process_object_detection(\n",
        "    outputs,\n",
        "    threshold=0.9,\n",
        "    target_sizes=target_sizes,\n",
        ")\n",
        "results = detections_per_image[0]"
      ],
      "metadata": {
        "id": "eRTVpCQzoKi4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# One line per detection: class name, rounded confidence and rounded box.\n",
        "for score, label, box in zip(results[\"scores\"], results[\"labels\"], results[\"boxes\"]):\n",
        "    class_name = model.config.id2label[label.item()]\n",
        "    confidence = round(score.item(), 3)\n",
        "    corners = [round(coordinate, 2) for coordinate in box.tolist()]\n",
        "    # Single f-string: the earlier backslash continuation pulled the next line's\n",
        "    # leading space into the text and printed a double space.\n",
        "    print(f\"Detected {class_name} with confidence {confidence} at location {corners}\")\n"
      ],
      "metadata": {
        "id": "gDF52egfoM3h"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Palette for bounding-box outlines. Each entry is a triple of floats in the\n",
        "# 0-1 range and is passed as `color=` to the rectangle drawn for a detection.\n",
        "COLORS = [\n",
        "    [0.0, 0.5, 0.8],\n",
        "    [0.9, 0.3, 0.1],\n",
        "    [0.9, 0.6, 0.1],\n",
        "    [0.4, 0.1, 0.5],\n",
        "    [0.4, 0.6, 0.1],\n",
        "    [0.3, 0.7, 0.9],\n",
        "]\n",
        "\n",
        "\n",
        "def visualize_prediction(pil_img, output_dict, threshold):\n",
        "    \"\"\"Show ``pil_img`` with a box and class caption for each confident detection.\n",
        "\n",
        "    Args:\n",
        "        pil_img: Image to display; passed straight to ``imshow``.\n",
        "        output_dict: Mapping with \"scores\", \"labels\" and \"boxes\" tensors, one\n",
        "            entry per detection. Each box is unpacked as (xmin, ymin, xmax, ymax).\n",
        "        threshold: Only detections scoring strictly above this value are drawn.\n",
        "\n",
        "    Class names are looked up in the notebook-level ``model.config.id2label`` and\n",
        "    outline colours come from the notebook-level ``COLORS`` palette.\n",
        "    \"\"\"\n",
        "    keep = output_dict[\"scores\"] > threshold\n",
        "    boxes = output_dict[\"boxes\"][keep].tolist()\n",
        "    label_ids = output_dict[\"labels\"][keep].tolist()\n",
        "    labels = [model.config.id2label[label_id] for label_id in label_ids]\n",
        "\n",
        "    _, ax = plt.subplots(figsize=(8, 5))\n",
        "    ax.imshow(pil_img)\n",
        "    # Hide the axes once, up front, so they are also hidden when nothing passes\n",
        "    # the threshold (previously this only ran inside the per-box loop).\n",
        "    ax.axis(\"off\")\n",
        "\n",
        "    for index, ((xmin, ymin, xmax, ymax), label) in enumerate(zip(boxes, labels)):\n",
        "        # Wrap around the palette so any number of boxes gets a colour.\n",
        "        color = COLORS[index % len(COLORS)]\n",
        "        ax.add_patch(\n",
        "            plt.Rectangle(\n",
        "                (xmin, ymin),\n",
        "                xmax - xmin,\n",
        "                ymax - ymin,\n",
        "                fill=False,\n",
        "                color=color,\n",
        "                linewidth=3,\n",
        "            )\n",
        "        )\n",
        "        # Caption sits at the box's (xmin, ymin) corner.\n",
        "        ax.text(xmin, ymin, label, fontsize=8, bbox=dict(facecolor=\"yellow\", alpha=0.5))\n",
        "    plt.show()\n"
      ],
      "metadata": {
        "id": "o0euTYtMoQt1"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# `results` was already cut at 0.9 during post-processing, so the threshold\n",
        "# here re-applies the same cut before drawing.\n",
        "visualize_prediction(pil_img=image, output_dict=results, threshold=0.9)\n"
      ],
      "metadata": {
        "id": "_y6hjplIoeUA"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}